背景介绍
WGAN(Wasserstein Generative Adversarial Networks):于2017年提出,和LSGAN类似,没有对网络结构做太多修改,分析了GAN网络中判别器效果越好,生成器梯度消失越严重的问题,而且提出了一种新的损失函数,构建了一个更加稳定,收敛更快,质量更高的生成式对抗网络。
WGAN特点
保持GAN的网络结构不变,将判别器网络最后的sigmoid删去。
将损失函数中的log删去。
每次更新判别器的参数,将参数绝对值截断到一个固定常数c。
不使用基于动量的优化算法(Adam),推荐使用RMSProp,SGD等方法。
WGAN图像分析
TensorFlow2.0实现
1 | import os |
模型运行结果
小技巧
- 图像输入可以先将其归一化到0-1之间或者-1-1之间,因为网络的参数一般都比较小,所以归一化后计算方便,收敛较快。
- 注意其中的一些维度变换和numpy,tensorflow常用操作,否则在阅读代码时可能会产生一些困难。
- 可以设置一些权重的保存方式,学习率的下降方式和早停方式。
- WGAN对于网络结构,优化器参数,网络层的一些超参数都是非常敏感的,效果不好不容易发现原因,这可能需要较多的工程实践经验。
- 先创建判别器,然后进行compile,这样判别器就固定了,然后创建生成器时,不要训练判别器,需要将判别器的trainable改成False,此时不会影响之前固定的判别器,这个可以通过模型的_collection_collected_trainable_weights属性查看,如果该属性为空,则模型不训练,否则模型可以训练,compile之后,该属性固定,无论后面如何修改trainable,只要不重新compile,都不影响训练。
- 本博客中的WGAN是在GAN的基础上进行修改,当然小伙伴们也可以尝试在DCGAN,CGAN等模型上进行尝试,可能一些超参数设置的不是非常合理,所以WGAN的效果不是特别好,小伙伴们在使用时可以自己修改。
WGAN小结
WGAN在提出时对网络的损失函数进行了大量的分析,引入W距离,Lipschitz常数等等,我不是大佬,也不对数学公式进行过多的阐述,可能我说了会让小伙伴们更加迷糊,因此有需要的小伙伴们可以去网上搜索相关资料。因为WGAN基本没有修改网络结构,因此网络参数和GAN完全相同。